package com.ikingtech.framework.sdk.limit.mvc.filter;

import com.ikingtech.framework.sdk.core.response.R;
import com.ikingtech.framework.sdk.enums.limit.RateLimitPriorityEnum;
import com.ikingtech.framework.sdk.limit.model.RateLimitResult;
import com.ikingtech.framework.sdk.limit.config.properties.RateLimitProperties;
import com.ikingtech.framework.sdk.limit.embedded.RateLimiterService;
import com.ikingtech.framework.sdk.limit.mvc.MvcRateLimiterFactory;
import com.ikingtech.framework.sdk.limit.mvc.RateLimitInfo;
import com.ikingtech.framework.sdk.utils.Tools;
import io.undertow.server.HttpServerExchange;
import io.undertow.servlet.spec.HttpServletRequestImpl;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.jetbrains.annotations.NotNull;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.MediaType;
import org.springframework.util.AntPathMatcher;
import org.springframework.web.filter.OncePerRequestFilter;

import jakarta.servlet.FilterChain;
import jakarta.servlet.ServletException;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.List;
import java.util.Map;

/**
 * @author zhangqiang
 */
@Slf4j
@RequiredArgsConstructor
public class MvcRateLimiterFilter extends OncePerRequestFilter {

    private final RateLimitProperties rateLimitProperties;

    private final RateLimiterService rateLimiterService;

    private final AntPathMatcher antPathMatcher = new AntPathMatcher();

    @Value("${spring.application.name}")
    private String name;

    @Override
    protected void doFilterInternal(@NotNull HttpServletRequest request,
                                    @NotNull HttpServletResponse response,
                                    @NotNull FilterChain filterChain) throws ServletException, IOException {
        HttpServletRequestImpl requestImpl = (HttpServletRequestImpl) request;
        HttpServerExchange exchange = requestImpl.getExchange();
        String requestUri = exchange.getRequestURI();
        String requestType = exchange.getRequestMethod().toString();
        if (!rateLimitProperties.isEnable()) {
            filterChain.doFilter(request, response);
            return;
        }
        //覆盖模式下，只按照配置中的规则执行，注解中的规则将被忽略
        if (RateLimitPriorityEnum.COVER.equals(rateLimitProperties.getPriority())) {
            boolean propertiesResponseReturn = doRateLimitWithProperties(request, response, filterChain, requestUri, requestType);
            if (propertiesResponseReturn) {
                return;
            }
        }
        //补充模式下，对配置中的规则和注解的规则取并集，有冲突时，以配置优先（默认为补充模式）
        if (RateLimitPriorityEnum.COMPLEMENT.equals(rateLimitProperties.getPriority())) {
            boolean propertiesResponseReturn = doRateLimitWithProperties(request, response, filterChain, requestUri, requestType);
            if (propertiesResponseReturn) {
                return;
            }
            if (doRateLimitWithAnnotation(response, requestUri, requestType)) {
                return;
            }
        }
        filterChain.doFilter(request, response);
    }

    private boolean doRateLimitWithAnnotation(HttpServletResponse response, String requestUri, String requestType) throws IOException {
        Map<String, RateLimitInfo> rateLimitInfoMap = MvcRateLimiterFactory.getRateLimitInfoMap();
        String mapKey = requestType + requestUri;
        RateLimitInfo rateLimitInfo = rateLimitInfoMap.get(mapKey);
        List<RateLimitProperties.Request> includeList = rateLimitProperties.getIncludeList();
        //如果includeList为空，并且有注解标注则限流该请求
        if (Tools.Coll.isBlank(includeList) && rateLimitInfo != null) {
            String cacheKey = getCacheKey(requestUri);
            RateLimitResult tryAcquire = rateLimiterService.tryAcquire(cacheKey, rateLimitInfo.getCount(), rateLimitInfo.getTime());
            if (Boolean.FALSE.equals(tryAcquire.getSuccess())) {
                rateLimitRequest(response, tryAcquire.getErrMsg());
                return true;
            }
        }
        return false;
    }

    private boolean doRateLimitWithProperties(HttpServletRequest request,
                                              HttpServletResponse response,
                                              FilterChain filterChain,
                                              String requestUri,
                                              String requestType) throws IOException, ServletException {
        List<RateLimitProperties.Request> excludeList = rateLimitProperties.getExcludeList();
        RateLimitProperties.Request excludeMatch = match(requestType, requestUri, excludeList);
        //如果excludeList不空，并且该请求存在于其中，则放行该请求
        if (Tools.Coll.isNotBlank(excludeList) && excludeMatch != null) {
            filterChain.doFilter(request, response);
            return true;
        }
        List<RateLimitProperties.Request> includeList = rateLimitProperties.getIncludeList();
        RateLimitProperties.Request includeMatch = match(requestType, requestUri, includeList);
        //如果includeList不为空并且该请求存在于其中，那么对该请求执行限流操作
        if ((Tools.Coll.isNotBlank(includeList) && includeMatch != null)) {
            Integer time = getDefaultValue(includeMatch.getTime(), rateLimitProperties.getTime());
            Integer count = getDefaultValue(includeMatch.getCount(), rateLimitProperties.getCount());
            String cacheKey = getCacheKey(requestUri);
            RateLimitResult tryAcquire = rateLimiterService.tryAcquire(cacheKey, count, time);
            if (Boolean.FALSE.equals(tryAcquire.getSuccess())) {
                rateLimitRequest(response, tryAcquire.getErrMsg());
                return true;
            }
        }
        return false;
    }

    private Integer getDefaultValue(Integer propertiesRequestValue, Integer propertiesDefaultValue) {
        if (propertiesRequestValue == null || propertiesRequestValue <= 0) {
            return propertiesDefaultValue;
        }
        return propertiesRequestValue;
    }

    private void rateLimitRequest(HttpServletResponse response, String cause) throws IOException {
        response.setCharacterEncoding("utf-8");
        response.setContentType(MediaType.APPLICATION_JSON_VALUE);
        PrintWriter writer = response.getWriter();
        writer.write(Tools.Json.toJsonStr(R.failed(cause)));
    }

    private RateLimitProperties.Request match(String requestMethod, String requestUri, List<RateLimitProperties.Request> list) {
        if (Tools.Coll.isNotBlank(list)) {
            for (RateLimitProperties.Request request : list) {
                if (requestMethod.equalsIgnoreCase(request.getType()) && request.getUri() != null
                        && antPathMatcher.match(request.getUri(), requestUri)) {
                    return request;
                }
            }
        }
        return null;
    }

    private String getCacheKey(String requestUri) {
        return this.rateLimitProperties.getCacheKeyPrefix() +
                Tools.Str.COLON +
                name +
                Tools.Str.COLON +
                Tools.Str.COLON +
                requestUri;
    }
}
